{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interpretability - Image Explainers\n",
    "\n",
    "In this example, we use the LIME and Kernel SHAP explainers to explain the ResNet50 model's multi-class output for an image.\n",
    "\n",
    "First we import the packages and define some UDFs and a plotting function we will need later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.explainers import *\n",
    "from synapse.ml.onnx import ONNXModel\n",
    "from synapse.ml.opencv import ImageTransformer\n",
    "from synapse.ml.io import *\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.sql.functions import *\n",
    "from pyspark.sql.types import *\n",
    "import numpy as np\n",
    "import urllib.request\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from synapse.ml.core.platform import *\n",
    "\n",
    "\n",
    "@udf(returnType=ArrayType(FloatType()))\n",
    "def vec_slice(vec, indices):\n",
    "    \"\"\"Pick the entries of `vec` at positions `indices` and return them as a list.\"\"\"\n",
    "    return vec.toArray()[indices].tolist()\n",
    "\n",
    "\n",
    "@udf(returnType=ArrayType(IntegerType()))\n",
    "def arg_top_k(vec, k):\n",
    "    \"\"\"Return the positions of the `k` largest entries of `vec`, largest first.\"\"\"\n",
    "    return (-vec.toArray()).argsort()[:k].tolist()\n",
    "\n",
    "\n",
    "def downloadBytes(url: str):\n",
    "    \"\"\"Open `url` and return everything read from the response.\"\"\"\n",
    "    # Bind the response to its own name so the `url` argument is not shadowed.\n",
    "    with urllib.request.urlopen(url) as response:\n",
    "        return response.read()\n",
    "\n",
    "\n",
    "def rotate_color_channel(bgr_image_array, height, width, nChannels):\n",
    "    \"\"\"Reorder a BGR pixel buffer into a (height, width, 3) RGB array.\n",
    "\n",
    "    Channels beyond the first three are dropped.\n",
    "    \"\"\"\n",
    "    # Transposing puts the channel axis first so each colour plane can be picked out.\n",
    "    planes = np.asarray(bgr_image_array).reshape(height, width, nChannels).T\n",
    "    blue, green, red = planes[:3]\n",
    "    return np.array((red, green, blue)).T\n",
    "\n",
    "\n",
    "def plot_superpixels(image_rgb_array, sp_clusters, weights, green_threshold=99):\n",
    "    \"\"\"Display the image with its top-weighted superpixels tinted green.\n",
    "\n",
    "    A superpixel is tinted when its weight is strictly greater than the\n",
    "    `green_threshold`-th percentile of `weights`.\n",
    "    \"\"\"\n",
    "    cutoff = np.percentile(weights, green_threshold)\n",
    "    rgba_image = Image.fromarray(image_rgb_array, mode=\"RGB\").convert(\"RGBA\")\n",
    "    canvas = np.asarray(rgba_image).copy()\n",
    "    for cluster, weight in zip(sp_clusters, weights):\n",
    "        if not weight > cutoff:\n",
    "            continue\n",
    "        for x, y in cluster:\n",
    "            canvas[y, x, 1] = 255  # green channel to maximum\n",
    "            canvas[y, x, 3] = 200  # alpha channel\n",
    "    plt.clf()\n",
    "    plt.imshow(canvas)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a DataFrame for a test image, and use the ResNet50 ONNX model to run inference on it.\n",
    "\n",
    "The result shows 39.6% probability of \"violin\" (889), and 38.4% probability of \"upright piano\" (881)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.io import *\n",
    "\n",
    "# Load the sample image into a single-row Spark DataFrame.\n",
    "image_df = spark.read.image().load(\n",
    "    \"wasbs://publicwasb@mmlspark.blob.core.windows.net/explainers/images/david-lusvardi-dWcUncxocQY-unsplash.jpg\"\n",
    ")\n",
    "display(image_df)\n",
    "\n",
    "# Rotate the image array from BGR into RGB channels for visualization later.\n",
    "row = image_df.select(\n",
    "    \"image.height\", \"image.width\", \"image.nChannels\", \"image.data\"\n",
    ").head()\n",
    "# Bind the fields by name instead of injecting them through locals().update(),\n",
    "# which only works at module scope and hides where these variables come from.\n",
    "height, width, nChannels, data = (\n",
    "    row[\"height\"],\n",
    "    row[\"width\"],\n",
    "    row[\"nChannels\"],\n",
    "    row[\"data\"],\n",
    ")\n",
    "rgb_image_array = rotate_color_channel(data, height, width, nChannels)\n",
    "\n",
    "# Download the ONNX model\n",
    "model_url = (\n",
    "    \"https://mmlspark.blob.core.windows.net/publicwasb/ONNXModels/resnet50-v2-7.onnx\"\n",
    ")\n",
    "modelPayload = downloadBytes(model_url)\n",
    "\n",
    "# Image preprocessing stage, built one step at a time.\n",
    "featurizer = ImageTransformer(inputCol=\"image\", outputCol=\"features\")\n",
    "featurizer = featurizer.resize(224, True).centerCrop(224, 224)\n",
    "featurizer = featurizer.normalize(\n",
    "    mean=[0.485, 0.456, 0.406],\n",
    "    std=[0.229, 0.224, 0.225],\n",
    "    color_scale_factor=1 / 255,\n",
    ")\n",
    "featurizer = featurizer.setTensorElementType(FloatType())\n",
    "\n",
    "# ONNX scoring stage configured from the downloaded payload.\n",
    "onnx = ONNXModel().setModelPayload(modelPayload)\n",
    "onnx = onnx.setFeedDict({\"data\": \"features\"})\n",
    "onnx = onnx.setFetchDict({\"rawPrediction\": \"resnetv24_dense0_fwd\"})\n",
    "onnx = onnx.setSoftMaxDict({\"rawPrediction\": \"probability\"})\n",
    "onnx = onnx.setMiniBatchSize(1)\n",
    "\n",
    "pipeline = Pipeline(stages=[featurizer, onnx])\n",
    "model = pipeline.fit(image_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scored = model.transform(image_df)\n",
    "# Indices of the two highest-probability classes, then their probabilities.\n",
    "with_top2 = scored.withColumn(\"top2pred\", arg_top_k(col(\"probability\"), lit(2)))\n",
    "predicted = with_top2.withColumn(\n",
    "    \"top2prob\", vec_slice(col(\"probability\"), col(\"top2pred\"))\n",
    ")\n",
    "\n",
    "display(predicted.select(\"top2pred\", \"top2prob\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we use the LIME image explainer to explain the model's top 2 classes' probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the LIME image explainer step by step.\n",
    "lime = ImageLIME().setModel(model)\n",
    "lime = lime.setOutputCol(\"weights\").setInputCol(\"image\")\n",
    "lime = lime.setCellSize(150.0).setModifier(50.0)\n",
    "lime = lime.setNumSamples(500)\n",
    "lime = lime.setTargetCol(\"probability\").setTargetClassesCol(\"top2pred\")\n",
    "lime = lime.setSamplingFraction(0.7)\n",
    "\n",
    "# Split the per-class weights; positions are assumed to follow the order of `top2pred`.\n",
    "lime_weights = col(\"weights\")\n",
    "lime_result = (\n",
    "    lime.transform(predicted)\n",
    "    .withColumn(\"weights_violin\", lime_weights.getItem(0))\n",
    "    .withColumn(\"weights_piano\", lime_weights.getItem(1))\n",
    "    .cache()\n",
    ")\n",
    "\n",
    "display(lime_result.select(\"weights_violin\", \"weights_piano\"))\n",
    "lime_row = lime_result.head()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We plot the LIME weights for \"violin\" output and \"upright piano\" output.\n",
    "\n",
    "Green areas are superpixels with LIME weights above the 95th percentile."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lime_clusters = lime_row[\"superpixels\"][\"clusters\"]\n",
    "for weights_col in (\"weights_violin\", \"weights_piano\"):\n",
    "    # Tint superpixels whose weight exceeds the 95th percentile.\n",
    "    plot_superpixels(rgb_image_array, lime_clusters, list(lime_row[weights_col]), 95)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Your results will look like:\n",
    "\n",
    "<img src=\"https://mmlspark.blob.core.windows.net/graphics/explainers/image-lime-20210811.png\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we use the Kernel SHAP image explainer to explain the model's top 2 classes' probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the Kernel SHAP image explainer step by step.\n",
    "shap = ImageSHAP().setModel(model)\n",
    "shap = shap.setOutputCol(\"shaps\").setSuperpixelCol(\"superpixels\")\n",
    "shap = shap.setInputCol(\"image\")\n",
    "shap = shap.setCellSize(150.0).setModifier(50.0)\n",
    "shap = shap.setNumSamples(500)\n",
    "shap = shap.setTargetCol(\"probability\").setTargetClassesCol(\"top2pred\")\n",
    "\n",
    "# Split the per-class SHAP vectors; positions are assumed to follow the order of `top2pred`.\n",
    "shap_values = col(\"shaps\")\n",
    "shap_result = (\n",
    "    shap.transform(predicted)\n",
    "    .withColumn(\"shaps_violin\", shap_values.getItem(0))\n",
    "    .withColumn(\"shaps_piano\", shap_values.getItem(1))\n",
    "    .cache()\n",
    ")\n",
    "\n",
    "display(shap_result.select(\"shaps_violin\", \"shaps_piano\"))\n",
    "shap_row = shap_result.head()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We plot the SHAP values for the \"violin\" output and the \"upright piano\" output.\n",
    "\n",
    "Green areas are superpixels with SHAP values above the 95th percentile.\n",
    "\n",
    "> Notice that we drop the base value from the SHAP output before rendering the superpixels. The base value is the model output for the background (all black) image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shap_clusters = shap_row[\"superpixels\"][\"clusters\"]\n",
    "for shaps_col in (\"shaps_violin\", \"shaps_piano\"):\n",
    "    # Skip element 0 of each SHAP vector before plotting.\n",
    "    plot_superpixels(rgb_image_array, shap_clusters, list(shap_row[shaps_col][1:]), 95)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Your results will look like:\n",
    "\n",
    "<img src=\"https://mmlspark.blob.core.windows.net/graphics/explainers/image-shap-20210811.png\"/>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}